{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Installing packages:\n",
      "\t.package(path: \"/home/sgugger/git/course-v3/nbs/swift/FastaiNotebook_03_minibatch_training\")\n",
      "\t\tFastaiNotebook_03_minibatch_training\n",
      "With SwiftPM flags: []\n",
      "Working in: /tmp/tmpihf54r2k/swift-install\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "warning: /home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)[1/2] Compiling jupyterInstalledPackages jupyterInstalledPackages.swift\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "[2/3] Merging module jupyterInstalledPackages\n",
      "/home/sgugger/swift/usr/bin/swift: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift)\n",
      "/home/sgugger/swift/usr/bin/swiftc: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swiftc)\n",
      "/home/sgugger/swift/usr/bin/swift-autolink-extract: /home/sgugger/anaconda3/lib/libuuid.so.1: no version information available (required by /home/sgugger/swift/usr/bin/swift-autolink-extract)\n",
      "[3/3] Linking libjupyterInstalledPackages.so\n",
      "Initializing Swift...\n",
      "Installation complete!\n"
     ]
    }
   ],
   "source": [
    "%install-location $cwd/swift-install\n",
    "%install '.package(path: \"$cwd/FastaiNotebook_03_minibatch_training\")' FastaiNotebook_03_minibatch_training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "import Path\n",
    "import TensorFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import FastaiNotebook_03_minibatch_training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load our data and define a basic model like in the previous notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var (xTrain,yTrain,xValid,yValid) = loadMNIST(path: mnistPath, flat: true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "60000 784 10\r\n"
     ]
    }
   ],
   "source": [
    "// n and m are the sizes of the first two dimensions of xTrain (printed below as 60000 and 784).\n",
    "let (n,m) = (xTrain.shape[0],xTrain.shape[1])\n",
    "// Largest label value plus one; used further down as the output size of the model.\n",
    "let c = yTrain.max().scalarized()+1\n",
    "print(n,m,c)\n",
    "// Size of the hidden layer used by the models defined below.\n",
    "let nHid = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "/// A two-layer fully connected model: `layer1` (ReLU activation) followed by `layer2`.\n",
    "public struct BasicModel: Layer {\n",
    "    public var layer1: FADense<Float>\n",
    "    public var layer2: FADense<Float>\n",
    "\n",
    "    /// Creates both dense layers.\n",
    "    /// - Parameters:\n",
    "    ///   - nIn: input size of `layer1`.\n",
    "    ///   - nHid: output size of `layer1` and input size of `layer2`.\n",
    "    ///   - nOut: output size of `layer2`.\n",
    "    public init(nIn: Int, nHid: Int, nOut: Int) {\n",
    "        self.layer1 = FADense(nIn, nHid, activation: relu)\n",
    "        self.layer2 = FADense(nHid, nOut)\n",
    "    }\n",
    "\n",
    "    /// Applies `layer1` and then `layer2` to `input`.\n",
    "    @differentiable\n",
    "    public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {\n",
    "        let hidden = layer1(input)\n",
    "        return layer2(hidden)\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also directly define our model as an array of `FADense` layers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var model: [FADense<Float>] = [\n",
    "    FADense(m, nHid, activation: relu),\n",
    "    FADense(nHid, Int(c))] // BasicModel(nIn: m, nHid: nHid, nOut: Int(c))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset/DataBunch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We add our own wrapper around the S4TF `Dataset` for several reasons:\n",
    "- in S4TF, `Dataset` has no length and we need a `count` property to be able to do efficient hyper-parameter scheduling.\n",
    "- you can only apply `batched` once to a `Dataset` but we sometimes want to change the batch size. We save the original non-batched dataset in `innerDs`.\n",
    "- the shuffle needs to be called each time we want to reshuffle, so we make this happen in the computed property `ds`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export \n",
    "/// Wraps a non-batched `Dataset` together with its element count, batch size and shuffle flag.\n",
    "public struct FADataset<Element> where Element: TensorGroup {\n",
    "    /// The original, non-batched dataset.\n",
    "    public var innerDs: Dataset<Element>\n",
    "    /// Whether `ds` reshuffles the elements each time it is read.\n",
    "    public var shuffle = false\n",
    "    /// Batch size applied by `ds`.\n",
    "    public var bs = 64\n",
    "    /// Number of elements in `innerDs`, as supplied by the caller.\n",
    "    public var dsCount: Int\n",
    "\n",
    "    /// Number of batches: `dsCount / bs`, plus one when the division leaves a remainder.\n",
    "    public var count: Int {\n",
    "        let (fullBatches, leftover) = dsCount.quotientAndRemainder(dividingBy: bs)\n",
    "        return leftover == 0 ? fullBatches : fullBatches + 1\n",
    "    }\n",
    "\n",
    "    /// A freshly batched dataset built from `innerDs`; when `shuffle` is set it is first\n",
    "    /// reshuffled with a new random seed, so every read gives a different order.\n",
    "    public var ds: Dataset<Element> {\n",
    "        guard shuffle else { return innerDs.batched(bs) }\n",
    "        let seed = Int64.random(in: Int64.min..<Int64.max)\n",
    "        let reshuffled = innerDs.shuffled(sampleCount: dsCount, randomSeed: seed)\n",
    "        return reshuffled.batched(bs)\n",
    "    }\n",
    "\n",
    "    /// - Parameters:\n",
    "    ///   - ds: the non-batched dataset to wrap.\n",
    "    ///   - len: number of elements in `ds`.\n",
    "    ///   - shuffle: whether `ds` should be reshuffled on every read.\n",
    "    ///   - bs: batch size.\n",
    "    public init(_ ds: Dataset<Element>, len: Int, shuffle: Bool = false, bs: Int = 64) {\n",
    "        self.innerDs = ds\n",
    "        self.dsCount = len\n",
    "        self.shuffle = shuffle\n",
    "        self.bs = bs\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can define a `DataBunch` to group our training and validation datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "/// Groups a shuffled training `FADataset` with an unshuffled validation `FADataset`.\n",
    "public struct DataBunch<Element> where Element: TensorGroup {\n",
    "    public var train: FADataset<Element>\n",
    "    public var valid: FADataset<Element>\n",
    "\n",
    "    /// - Parameters:\n",
    "    ///   - train: non-batched training dataset.\n",
    "    ///   - valid: non-batched validation dataset.\n",
    "    ///   - trainLen: number of elements in `train`.\n",
    "    ///   - validLen: number of elements in `valid`.\n",
    "    ///   - bs: training batch size; the validation set uses twice this value.\n",
    "    public init(train: Dataset<Element>, valid: Dataset<Element>, trainLen: Int, validLen: Int, bs: Int = 64) {\n",
    "        let validBs = 2 * bs\n",
    "        self.train = FADataset(train, len: trainLen, shuffle: true, bs: bs)\n",
    "        self.valid = FADataset(valid, len: validLen, shuffle: false, bs: validBs)\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And add a convenience function to get MNIST in a `DataBunch` directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "/// Loads MNIST with `loadMNIST` and wraps the training and validation tensors in a `DataBunch`.\n",
    "/// - Parameters:\n",
    "///   - path: forwarded to `loadMNIST`.\n",
    "///   - flat: forwarded to `loadMNIST`.\n",
    "///   - bs: training batch size handed to `DataBunch`.\n",
    "public func mnistDataBunch(path: Path = mnistPath, flat: Bool = false, bs: Int = 64)\n",
    "   -> DataBunch<DataBatch<TF, TI>> {\n",
    "    let (xTrain, yTrain, xValid, yValid) = loadMNIST(path: path, flat: flat)\n",
    "    let trainDs = Dataset(elements: DataBatch(xb: xTrain, yb: yTrain))\n",
    "    let validDs = Dataset(elements: DataBatch(xb: xValid, yb: yValid))\n",
    "    // The first dimension of each input tensor is used as the element count of its dataset.\n",
    "    return DataBunch(train: trainDs,\n",
    "                     valid: validDs,\n",
    "                     trainLen: xTrain.shape[0],\n",
    "                     validLen: xValid.shape[0],\n",
    "                     bs: bs)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "let data = mnistDataBunch(flat: true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "938\n"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train.count"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shuffle test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Timing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "public extension Sequence {\n",
    "  /// Returns the first element of the sequence, or `nil` when the sequence is empty.\n",
    "  func first() -> Element? {\n",
    "    var iterator = makeIterator()\n",
    "    return iterator.next()\n",
    "  }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average: 273.9586643 ms,   min: 270.49003 ms,   max: 292.917936 ms\r\n"
     ]
    }
   ],
   "source": [
    "time(repeating: 10) {\n",
    "  let tst = data.train.ds\n",
    "\n",
    "  tst.first()!.yb\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check we get different batches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[9, 4, 7, 5, 3, 4, 2, 3, 2, 8, 7, 8, 9, 0, 4, 1, 1, 2, 2, 2, 7, 6, 6, 3, 1, 1, 7, 9, 1, 2, 6, 4, 4, 4, 7, 0, 6, 6, 8, 9, 2, 2, 2, 8, 4, 2, 9, 3, 8, 1, 9, 8, 8, 5, 6, 0, 9, 6, 1, 7, 8, 2, 4, 2]\n"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "var tst = data.train.ds\n",
    "tst.first()!.yb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[6, 0, 4, 6, 5, 8, 4, 1, 1, 9, 0, 8, 7, 8, 7, 9, 0, 8, 0, 6, 0, 3, 6, 1, 3, 8, 7, 9, 4, 8, 4, 4, 4, 8, 0, 8, 5, 5, 8, 9, 7, 5, 4, 8, 3, 9, 1, 2, 5, 1, 7, 8, 1, 7, 8, 9, 5, 2, 8, 6, 4, 0, 3, 8]\n"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tst = data.train.ds\n",
    "tst.first()!.yb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `Learner`,  `LearnerAction`: enums and error handling in Swift, oh my!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just like in Python, we'll use \"exception handling\" to let custom actions indicate that they want to stop, skip over a batch or do other custom processing - e.g. for early stopping.\n",
    "\n",
    "We'll start by defining a custom type to represent the stop reason, and we'll use a Swift enum to describe it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "/// Control-flow signals that can be thrown during training to alter the training loop.\n",
    "public enum LearnerAction: Error {\n",
    "    /// Abandons the rest of the current epoch; `fit` catches it and calls `epochSkipped` on each delegate.\n",
    "    case skipEpoch(reason: String)\n",
    "    /// Abandons the current batch; the batch loop catches it and calls `batchSkipped` on each delegate.\n",
    "    case skipBatch(reason: String)\n",
    "    /// Ends training; `fit` catches it and calls `trainingStopped` on each delegate.\n",
    "    case stop(reason: String)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now this is a bit of an unusual thing - we have met protocols before, and `: Error` is a protocol that `LearnerAction` conforms to, but what is going on with those cases?\n",
    "\n",
    "Let's jump briefly into slides to talk about Swift enums:\n",
    "\n",
    "**Slides:** [Supercharged Enums in Swift](https://docs.google.com/presentation/d/1dc6o2o-uYGnJeCeyvgsgyk05dBMneArxdICW5vF75oU/edit#slide=id.g512a2e238a_144_147)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Basic `Learner` class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "/// Initializes and trains a model on a given dataset.\n",
    "///\n",
    "/// The learner owns the data, the model, the optimizer and the loss function, and exposes the\n",
    "/// state of the batch/epoch being processed as properties so that `Delegate` callbacks can\n",
    "/// read (and, for the `current*` values, modify) it.\n",
    "public final class Learner<Label: TensorGroup,\n",
    "                           Opt: TensorFlow.Optimizer & AnyObject>\n",
    "    where Opt.Scalar: Differentiable,\n",
    "          Opt.Model: Layer,\n",
    "          // Constrain model input to Tensor<Float>, to work around\n",
    "          // https://forums.fast.ai/t/fix-ad-crash-in-learner/42970.\n",
    "          Opt.Model.Input == Tensor<Float>\n",
    "{\n",
    "    public typealias Model = Opt.Model\n",
    "    public typealias Input = Model.Input\n",
    "    public typealias Output = Model.Output\n",
    "    public typealias Data = DataBunch<DataBatch<Input, Label>>\n",
    "    public typealias Loss = TF\n",
    "    public typealias Optimizer = Opt\n",
    "    public typealias Variables = Model.AllDifferentiableVariables\n",
    "    public typealias EventHandler = (Learner) throws -> Void\n",
    "    \n",
    "    /// A wrapper class to hold the loss function, to work around\n",
    "    /// https://forums.fast.ai/t/fix-ad-crash-in-learner/42970.\n",
    "    public final class LossFunction {\n",
    "        /// Differentiable with respect to the model output only; the label is `@nondiff`.\n",
    "        public typealias F = @differentiable (Model.Output, @nondiff Label) -> Loss\n",
    "        public var f: F\n",
    "        init(_ f: @escaping F) { self.f = f }\n",
    "    }\n",
    "    \n",
    "    /// Training and validation datasets.\n",
    "    public var data: Data\n",
    "    /// Optimizer, built by `optFunc` from the freshly initialized model.\n",
    "    public var opt: Optimizer\n",
    "    /// Wrapped loss function.\n",
    "    public var lossFunc: LossFunction\n",
    "    /// Model, built by `modelInit`.\n",
    "    public var model: Model\n",
    "    \n",
    "    // Input, target and model output of the batch being processed. They are implicitly\n",
    "    // unwrapped optionals, so they are nil until the first batch has been seen.\n",
    "    public var currentInput: Input!\n",
    "    public var currentTarget: Label!\n",
    "    public var currentOutput: Output!\n",
    "    \n",
    "    /// Number of epochs requested by the latest call to `fit`.\n",
    "    public private(set) var epochCount = 0\n",
    "    /// Zero-based index of the epoch being run.\n",
    "    public private(set) var currentEpoch = 0\n",
    "    /// Gradient of the loss computed on the latest training batch.\n",
    "    public private(set) var currentGradient = Model.TangentVector.zero\n",
    "    /// Loss computed on the latest batch.\n",
    "    public private(set) var currentLoss = Loss.zero\n",
    "    /// Selects between the training step and the evaluation step for each batch.\n",
    "    public private(set) var inTrain = false\n",
    "    /// Progress counter expressed in epochs; maintained by `TrainEvalDelegate`.\n",
    "    public private(set) var pctEpochs = Float.zero\n",
    "    /// Index of the current batch; maintained by `TrainEvalDelegate`.\n",
    "    public private(set) var currentIter = 0\n",
    "    /// Number of batches in the dataset being iterated.\n",
    "    public private(set) var iterCount = 0\n",
    "    \n",
    "    /// Base class for callbacks. Every hook is a no-op here; subclasses override the ones they need.\n",
    "    /// A hook may throw a `LearnerAction` to skip a batch, skip an epoch or stop training.\n",
    "    open class Delegate {\n",
    "        /// Sort key: delegates are stored in ascending `order`.\n",
    "        open var order: Int { return 0 }\n",
    "        public init () {}\n",
    "        \n",
    "        open func trainingWillStart(learner: Learner) throws {}\n",
    "        open func trainingDidFinish(learner: Learner) throws {}\n",
    "        open func epochWillStart(learner: Learner) throws {}\n",
    "        open func epochDidFinish(learner: Learner) throws {}\n",
    "        open func validationWillStart(learner: Learner) throws {}\n",
    "        open func batchWillStart(learner: Learner) throws {}\n",
    "        open func batchDidFinish(learner: Learner) throws {}\n",
    "        open func didProduceNewGradient(learner: Learner) throws {}\n",
    "        open func optimizerDidUpdate(learner: Learner) throws {}\n",
    "        open func batchSkipped(learner: Learner, reason:String) throws {}\n",
    "        open func epochSkipped(learner: Learner, reason:String) throws {}\n",
    "        open func trainingStopped(learner: Learner, reason:String) throws {}\n",
    "        ///\n",
    "        /// TODO: learnerDidProduceNewOutput and learnerDidProduceNewLoss need to\n",
    "        /// be differentiable once we can have the loss function inside the Learner\n",
    "    }\n",
    "    \n",
    "    /// Registered callbacks, re-sorted by `order` every time the array is assigned or mutated.\n",
    "    public var delegates: [Delegate] = [] {\n",
    "        didSet { delegates.sort { $0.order < $1.order } }\n",
    "    }\n",
    "    \n",
    "    /// - Parameters:\n",
    "    ///   - data: training and validation datasets.\n",
    "    ///   - lossFunc: differentiable loss taking the model output and the label.\n",
    "    ///   - optFunc: builds the optimizer from the initialized model.\n",
    "    ///   - modelInit: builds the model; called exactly once, before `optFunc`.\n",
    "    public init(data: Data, lossFunc: @escaping LossFunction.F,\n",
    "                optFunc: (Model) -> Optimizer, modelInit: ()->Model) {\n",
    "        (self.data,self.lossFunc) = (data,LossFunction(lossFunc))\n",
    "        model = modelInit()\n",
    "        opt = optFunc(self.model)\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then let's write the parts of the training loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "extension Learner {\n",
    "    /// Evaluation step: runs the model on `currentInput` and stores the output and the loss.\n",
    "    /// The `batch` argument itself is not read; the step works from `currentInput`/`currentTarget`.\n",
    "    private func evaluate(onBatch batch: DataBatch<Input, Label>) throws {\n",
    "        currentOutput = model(currentInput)\n",
    "        currentLoss = lossFunc.f(currentOutput, currentTarget)\n",
    "    }\n",
    "    \n",
    "    /// Training step: computes the loss and its gradient with respect to the model, notifies the\n",
    "    /// delegates through `didProduceNewGradient`, then lets the optimizer update the model.\n",
    "    /// The `batch` argument itself is not read; the step works from `currentInput`/`currentTarget`.\n",
    "    private func train(onBatch batch: DataBatch<Input, Label>) throws {\n",
    "        let (xb,yb) = (currentInput!,currentTarget!) //We still have to force-unwrap those for AD...\n",
    "        (currentLoss, currentGradient) = model.valueWithGradient { model -> Loss in \n",
    "            let y = model(xb)                                      \n",
    "            // Keep the forward output around so delegates can read it afterwards.\n",
    "            self.currentOutput = y\n",
    "            return self.lossFunc.f(y, yb)\n",
    "        }\n",
    "        // Delegates are notified after the gradient is stored and before it is applied.\n",
    "        for d in delegates { try d.didProduceNewGradient(learner: self) }\n",
    "        opt.update(&model.variables, along: self.currentGradient)\n",
    "    }\n",
    "    \n",
    "    /// Runs one pass over `ds`, using the training step when `inTrain` is true and the evaluation\n",
    "    /// step otherwise. Errors other than `LearnerAction.skipBatch` propagate to the caller.\n",
    "    private func train(onDataset ds: FADataset<DataBatch<Input, Label>>) throws {\n",
    "        iterCount = ds.count\n",
    "        for batch in ds.ds {\n",
    "            (currentInput, currentTarget) = (batch.xb, batch.yb)\n",
    "            do {\n",
    "                for d in delegates { try d.batchWillStart(learner: self) }\n",
    "                if inTrain { try train(onBatch: batch) } else { try evaluate(onBatch: batch) }\n",
    "            }\n",
    "            catch LearnerAction.skipBatch(let reason) {\n",
    "                for d in delegates {try d.batchSkipped(learner: self, reason:reason)}\n",
    "            }\n",
    "            // Sent for every batch, including one that was just skipped.\n",
    "            for d in delegates { try d.batchDidFinish(learner: self) }\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And the whole fit function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "extension Learner {\n",
    "    /// Trains for `epochCount` epochs, running a validation pass after each training pass.\n",
    "    ///\n",
    "    /// A thrown `LearnerAction.skipEpoch` abandons the current epoch (delegates receive\n",
    "    /// `epochSkipped`, then `epochDidFinish`). A thrown `LearnerAction.stop` ends the loop\n",
    "    /// (delegates receive `trainingStopped`). `trainingDidFinish` is sent in both the normal\n",
    "    /// and the stopped case. Any other error is rethrown to the caller.\n",
    "    /// - Parameter epochCount: The number of epochs that will be run.\n",
    "    public func fit(_ epochCount: Int) throws {\n",
    "        self.epochCount = epochCount\n",
    "        do {\n",
    "            for delegate in delegates {\n",
    "                try delegate.trainingWillStart(learner: self)\n",
    "            }\n",
    "            for epoch in 0..<epochCount {\n",
    "                currentEpoch = epoch\n",
    "                do {\n",
    "                    for delegate in delegates {\n",
    "                        try delegate.epochWillStart(learner: self)\n",
    "                    }\n",
    "                    try train(onDataset: data.train)\n",
    "                    for delegate in delegates {\n",
    "                        try delegate.validationWillStart(learner: self)\n",
    "                    }\n",
    "                    try train(onDataset: data.valid)\n",
    "                } catch let LearnerAction.skipEpoch(reason) {\n",
    "                    for delegate in delegates {\n",
    "                        try delegate.epochSkipped(learner: self, reason: reason)\n",
    "                    }\n",
    "                }\n",
    "                // Sent for every epoch, including one that was just skipped.\n",
    "                for delegate in delegates {\n",
    "                    try delegate.epochDidFinish(learner: self)\n",
    "                }\n",
    "            }\n",
    "        } catch let LearnerAction.stop(reason) {\n",
    "            for delegate in delegates {\n",
    "                try delegate.trainingStopped(learner: self, reason: reason)\n",
    "            }\n",
    "        }\n",
    "\n",
    "        for delegate in delegates {\n",
    "            try delegate.trainingDidFinish(learner: self)\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/// Builds an SGD optimizer for `model` with a learning rate of 1e-2.\n",
    "func optFunc(_ model: BasicModel) -> SGD<BasicModel> {\n",
    "    return SGD(for: model, learningRate: 1e-2)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/// Builds a fresh `BasicModel` from the notebook-level sizes `m`, `nHid` and `c`.\n",
    "func modelInit() -> BasicModel {\n",
    "    let nOut = Int(c)\n",
    "    return BasicModel(nIn: m, nHid: nHid, nOut: nOut)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "let learner = Learner(data: data, lossFunc: softmaxCrossEntropy, optFunc: optFunc, modelInit: modelInit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.fit(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's add Callbacks!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Extension with convenience methods to add delegates:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "public extension Learner {\n",
    "    /// Registers one delegate; the `delegates` array re-sorts itself by `order` afterwards.\n",
    "    func addDelegate(_ delegate: Learner.Delegate) {\n",
    "        delegates += [delegate]\n",
    "    }\n",
    "\n",
    "    /// Registers several delegates in a single mutation of the `delegates` array.\n",
    "    func addDelegates(_ delegates: [Learner.Delegate]) {\n",
    "        self.delegates.append(contentsOf: delegates)\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train/eval"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Callback classes are defined as extensions of the Learner."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "extension Learner {\n",
    "    /// Delegate that keeps the learner's progress counters up to date and switches the\n",
    "    /// learning phase between training and inference.\n",
    "    public class TrainEvalDelegate: Delegate {\n",
    "        /// Resets the fractional epoch counter before the first epoch.\n",
    "        public override func trainingWillStart(learner: Learner) {\n",
    "            learner.pctEpochs = Float.zero\n",
    "        }\n",
    "\n",
    "        /// Enters the training phase and resets the per-epoch counters.\n",
    "        public override func epochWillStart(learner: Learner) {\n",
    "            Context.local.learningPhase = .training\n",
    "            learner.pctEpochs = Float(learner.currentEpoch)\n",
    "            learner.inTrain = true\n",
    "            learner.currentIter = 0\n",
    "        }\n",
    "\n",
    "        /// Counts the batch; while training, also advances `pctEpochs` by one batch's share of an epoch.\n",
    "        public override func batchDidFinish(learner: Learner) {\n",
    "            learner.currentIter += 1\n",
    "            guard learner.inTrain else { return }\n",
    "            learner.pctEpochs += 1.0 / Float(learner.iterCount)\n",
    "        }\n",
    "\n",
    "        /// Enters the inference phase and restarts the batch counter.\n",
    "        public override func validationWillStart(learner: Learner) {\n",
    "            Context.local.learningPhase = .inference\n",
    "            learner.currentIter = 0\n",
    "            learner.inTrain = false\n",
    "        }\n",
    "    }\n",
    "\n",
    "    /// Convenience factory for `TrainEvalDelegate`.\n",
    "    public func makeTrainEvalDelegate() -> TrainEvalDelegate {\n",
    "        return TrainEvalDelegate()\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "let learner = Learner(data: data, lossFunc: softmaxCrossEntropy, optFunc: optFunc, modelInit: modelInit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.delegates = [learner.makeTrainEvalDelegate()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.fit(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### AverageMetric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "extension Learner {\n",
    "    /// Delegate that averages the loss and a list of metrics over the validation batches of\n",
    "    /// each epoch, weighting every batch by its size, and prints them when the epoch finishes.\n",
    "    public class AvgMetric: Delegate {\n",
    "        /// Metric functions, each called with the model output and the target of a batch.\n",
    "        public let metrics: [(Output, Label) -> TF]\n",
    "        /// Number of validation samples accumulated in the current epoch.\n",
    "        var total: Int = 0\n",
    "        /// Running weighted sums: index 0 is the loss, index `i + 1` belongs to `metrics[i]`.\n",
    "        var partials = [TF]()\n",
    "\n",
    "        public init(metrics: [(Output, Label) -> TF]) { self.metrics = metrics }\n",
    "\n",
    "        /// Resets the accumulators at the start of every epoch.\n",
    "        public override func epochWillStart(learner: Learner) {\n",
    "            total = 0\n",
    "            partials = Array(repeating: Tensor(0), count: metrics.count + 1)\n",
    "        }\n",
    "\n",
    "        /// Accumulates the loss and each metric for validation batches; training batches are ignored.\n",
    "        public override func batchDidFinish(learner: Learner) {\n",
    "            guard !learner.inTrain else { return }\n",
    "            let bs = learner.currentInput!.shape[0] //Possible because Input is TF for now\n",
    "            total += bs\n",
    "            partials[0] += Float(bs) * learner.currentLoss\n",
    "            // Iterate with enumerated() rather than `1...metrics.count`: a closed range whose\n",
    "            // upper bound is below its lower bound traps at runtime, so an empty `metrics`\n",
    "            // array used to crash here.\n",
    "            for (i, metric) in metrics.enumerated() {\n",
    "                partials[i + 1] += Float(bs) * metric(learner.currentOutput!, learner.currentTarget!)\n",
    "            }\n",
    "        }\n",
    "\n",
    "        /// Turns the accumulated sums into averages and prints them.\n",
    "        public override func epochDidFinish(learner: Learner) {\n",
    "            // No validation batch was seen (for instance the epoch was skipped): report it\n",
    "            // instead of dividing by zero and printing NaN values.\n",
    "            guard total > 0 else {\n",
    "                print(\"Epoch \\(learner.currentEpoch): no validation batches\")\n",
    "                return\n",
    "            }\n",
    "            partials = partials.map { $0 / Float(total) }\n",
    "            print(\"Epoch \\(learner.currentEpoch): \\(partials)\")\n",
    "        }\n",
    "    }\n",
    "\n",
    "    /// Convenience factory for `AvgMetric`.\n",
    "    public func makeAvgMetric(metrics: [(Output, Label) -> TF]) -> AvgMetric {\n",
    "        return AvgMetric(metrics: metrics)\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "let learner = Learner(data: data, lossFunc: softmaxCrossEntropy, optFunc: optFunc, modelInit: modelInit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.delegates = [learner.makeTrainEvalDelegate(), learner.makeAvgMetric(metrics: [accuracy])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: [0.50405025, 0.8781]\n",
      "Epoch 1: [0.3750386, 0.9004]\n"
     ]
    }
   ],
   "source": [
    "learner.fit(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "extension Learner {\n",
    "    /// Delegate that replaces each input batch by `(input - mean) / std` before it is used.\n",
    "    public class Normalize: Delegate {\n",
    "        public let mean: TF\n",
    "        public let std: TF\n",
    "\n",
    "        public init(mean: TF, std: TF) {\n",
    "            self.mean = mean\n",
    "            self.std = std\n",
    "        }\n",
    "\n",
    "        /// Normalizes `learner.currentInput` in place.\n",
    "        public override func batchWillStart(learner: Learner) {\n",
    "            let centered = learner.currentInput! - mean\n",
    "            learner.currentInput = centered / std\n",
    "        }\n",
    "    }\n",
    "\n",
    "    /// Convenience factory for `Normalize`.\n",
    "    public func makeNormalize(mean: TF, std: TF) -> Normalize {\n",
    "        return Normalize(mean: mean, std: std)\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "▿ 2 elements\n",
       "  - mean : 0.13066049\n",
       "  - std : 0.30810767\n"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(mean: xTrain.mean(), std: xTrain.standardDeviation())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// export\n",
    "/// Mean and standard deviation passed to `makeNormalize` to normalize MNIST inputs.\n",
    "/// NOTE(review): presumably precomputed from the training images; the values differ in the\n",
    "/// last digits from the ones printed by the cell above, so confirm where they come from.\n",
    "public let mnistStats = (mean: TF(0.13066047), std: TF(0.3081079))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "let learner = Learner(data: data, lossFunc: softmaxCrossEntropy, optFunc: optFunc, modelInit: modelInit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.delegates = [learner.makeTrainEvalDelegate(), learner.makeAvgMetric(metrics: [accuracy]),\n",
    "                     learner.makeNormalize(mean: mnistStats.mean, std: mnistStats.std)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: [0.3052104, 0.911]\n",
      "Epoch 1: [0.25273553, 0.9291]\n"
     ]
    }
   ],
   "source": [
    "learner.fit(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "success\r\n"
     ]
    }
   ],
   "source": [
    "import NotebookExport\n",
    "let exporter = NotebookExport(Path.cwd/\"04_callbacks.ipynb\")\n",
    "print(exporter.export(usingPrefix: \"FastaiNotebook_\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Swift",
   "language": "swift",
   "name": "swift"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
